/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.client.standard;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Callable;

import javax.websocket.ClientEndpointConfig;
import javax.websocket.ClientEndpointConfig.Configurator;
import javax.websocket.ContainerProvider;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.WebSocketContainer;

import org.springframework.core.task.AsyncListenableTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureTask;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
import org.springframework.web.socket.client.AbstractWebSocketClient;

/**
 * A WebSocketClient based on standard Java WebSocket API.
 *
 * @author Rossen Stoyanchev
 * @since 4.0
 */
public class StandardWebSocketClient extends AbstractWebSocketClient {

    private final WebSocketContainer webSocketContainer;

    /** User properties copied into every {@link ClientEndpointConfig} built by this client. */
    private final Map<String, Object> userProperties = new HashMap<>();

    @Nullable
    private AsyncListenableTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();


    /**
     * Default constructor that calls {@code ContainerProvider.getWebSocketContainer()}
     * to obtain a (new) {@link WebSocketContainer} instance. Also see constructor
     * accepting existing {@code WebSocketContainer} instance.
     */
    public StandardWebSocketClient() {
        this.webSocketContainer = ContainerProvider.getWebSocketContainer();
    }

    /**
     * Constructor accepting an existing {@link WebSocketContainer} instance.
     * <p>For XML configuration, see {@link WebSocketContainerFactoryBean}. For Java
     * configuration, use {@code ContainerProvider.getWebSocketContainer()} to obtain
     * the {@code WebSocketContainer} instance.
     */
    public StandardWebSocketClient(WebSocketContainer webSocketContainer) {
        Assert.notNull(webSocketContainer, "WebSocketContainer must not be null");
        this.webSocketContainer = webSocketContainer;
    }

    /**
     * Adapt Spring {@link WebSocketExtension} instances to the standard
     * {@link Extension} type expected by {@link ClientEndpointConfig}.
     * @param extensions the extensions to adapt (may be empty)
     * @return a new mutable list with one adapter per input extension, in the same order
     */
    private static List<Extension> adaptExtensions(List<WebSocketExtension> extensions) {
        List<Extension> result = new ArrayList<>(extensions.size());
        for (WebSocketExtension extension : extensions) {
            result.add(new WebSocketToStandardExtensionAdapter(extension));
        }
        return result;
    }

    /**
     * The configured user properties.
     * <p>The returned map is the live internal map, not a copy.
     */
    public Map<String, Object> getUserProperties() {
        return this.userProperties;
    }

    /**
     * The standard Java WebSocket API allows passing "user properties" to the
     * server via {@link ClientEndpointConfig#getUserProperties() userProperties}.
     * Use this property to configure one or more properties to be passed on
     * every handshake.
     * <p>The given entries are added to (not substituted for) any properties already
     * configured; a {@code null} argument is ignored.
     */
    public void setUserProperties(@Nullable Map<String, Object> userProperties) {
        if (userProperties != null) {
            this.userProperties.putAll(userProperties);
        }
    }

    /**
     * Return the configured {@link TaskExecutor}.
     */
    @Nullable
    public AsyncListenableTaskExecutor getTaskExecutor() {
        return this.taskExecutor;
    }

    /**
     * Set an {@link AsyncListenableTaskExecutor} to use when opening connections.
     * If this property is set to {@code null}, calls to any of the
     * {@code doHandshake} methods will block until the connection is established.
     * <p>By default, an instance of {@code SimpleAsyncTaskExecutor} is used.
     */
    public void setTaskExecutor(@Nullable AsyncListenableTaskExecutor taskExecutor) {
        this.taskExecutor = taskExecutor;
    }

    /**
     * Perform the handshake by connecting to the server through the configured
     * {@link WebSocketContainer}.
     * <p>A {@link StandardWebSocketSession} is created up front and handed back once
     * {@code connectToServer} returns. If a task executor is configured, the connect
     * call runs on it; otherwise it runs on the calling thread, and the returned
     * future is already complete (successfully or exceptionally) when this method
     * returns.
     * @param webSocketHandler the client-side handler for WebSocket messages
     * @param headers the HTTP headers to use for the handshake, with unwanted (forbidden)
     * headers filtered out (never {@code null})
     * @param uri the target URI for the handshake (never {@code null})
     * @param protocols the sub-protocols to request, in order of preference
     * @param extensions requested WebSocket extensions, or an empty list
     * @param attributes attributes to associate with the WebSocketSession, i.e. via
     * {@link WebSocketSession#getAttributes()}; currently always an empty map.
     * @return a future that completes with the session once the connection is established
     * @throws IllegalArgumentException if the URI has no (server-based) host component
     */
    @Override
    protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
                                                                     HttpHeaders headers, final URI uri, List<String> protocols,
                                                                     List<WebSocketExtension> extensions, Map<String, Object> attributes) {

        // getHost() is null for URIs whose authority is not a valid server-based
        // host (e.g. "ws://my_host/"); fail fast with a message naming the URI
        // instead of the opaque error raised by InetSocketAddress.
        String host = uri.getHost();
        Assert.notNull(host, () -> "WebSocket URI must include a host: " + uri);

        int port = getPort(uri);
        // NOTE: the local address reuses the remote port, presumably because the
        // actual local port is not available before the connection is opened.
        InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), port);
        // This constructor attempts to resolve the host name, on the calling thread.
        InetSocketAddress remoteAddress = new InetSocketAddress(host, port);

        final StandardWebSocketSession session = new StandardWebSocketSession(headers,
                attributes, localAddress, remoteAddress);

        final ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
                .configurator(new StandardWebSocketClientConfigurator(headers))
                .preferredSubprotocols(protocols)
                .extensions(adaptExtensions(extensions)).build();

        endpointConfig.getUserProperties().putAll(getUserProperties());

        final Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);

        Callable<WebSocketSession> connectTask = () -> {
            this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
            return session;
        };

        if (this.taskExecutor != null) {
            return this.taskExecutor.submitListenable(connectTask);
        }
        else {
            // No executor: run the connect task synchronously so the returned
            // future already carries the result (or the failure).
            ListenableFutureTask<WebSocketSession> task = new ListenableFutureTask<>(connectTask);
            task.run();
            return task;
        }
    }

    /**
     * Return the address of the local host, falling back to the loopback address
     * when the local host name cannot be resolved.
     * @return the local host address, never {@code null}
     */
    private static InetAddress getLocalHost() {
        try {
            return InetAddress.getLocalHost();
        }
        catch (UnknownHostException ex) {
            return InetAddress.getLoopbackAddress();
        }
    }

    /**
     * Return the port of the given URI. If the URI does not specify one, return
     * the default for its scheme: 443 for {@code wss}, otherwise 80.
     * @param uri the target URI; its scheme is expected to be non-null
     * @return the explicit or default port
     */
    private static int getPort(URI uri) {
        if (uri.getPort() == -1) {
            String scheme = uri.getScheme().toLowerCase(Locale.ENGLISH);
            return ("wss".equals(scheme) ? 443 : 80);
        }
        return uri.getPort();
    }


    /**
     * {@link Configurator} that copies the handshake headers into the outgoing
     * request and trace-logs the request and response headers.
     * <p>Non-static because it uses the enclosing client's {@code logger}.
     */
    private class StandardWebSocketClientConfigurator extends Configurator {

        /** Headers to add to every handshake request. */
        private final HttpHeaders headers;

        public StandardWebSocketClientConfigurator(HttpHeaders headers) {
            this.headers = headers;
        }

        /**
         * Add the configured headers to the outgoing handshake request. Entries
         * already present under the same header name are replaced.
         */
        @Override
        public void beforeRequest(Map<String, List<String>> requestHeaders) {
            requestHeaders.putAll(this.headers);
            if (logger.isTraceEnabled()) {
                logger.trace("Handshake request headers: " + requestHeaders);
            }
        }

        /**
         * Trace-log the handshake response headers; the response is otherwise ignored.
         */
        @Override
        public void afterResponse(HandshakeResponse response) {
            if (logger.isTraceEnabled()) {
                logger.trace("Handshake response headers: " + response.getHeaders());
            }
        }
    }

}
